from parser1 import args
import numpy as np
import pickle
import scipy.sparse as sp
from scipy.sparse import csr_matrix
import torch
import torch.nn as nn
import torch.nn.functional as F


def matrix_to_tensor(cur_matrix):
    """Convert a scipy sparse matrix to a torch sparse COO tensor.

    Parameters
    ----------
    cur_matrix : scipy.sparse matrix
        Any scipy sparse format; normalised to COO internally so the
        ``.row`` / ``.col`` / ``.data`` attributes are available.

    Returns
    -------
    torch.Tensor
        A float32 sparse COO tensor, moved to GPU when CUDA is available.
    """
    if not sp.isspmatrix_coo(cur_matrix):
        cur_matrix = cur_matrix.tocoo()
    indices = torch.from_numpy(
        np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64))
    values = torch.from_numpy(cur_matrix.data)
    shape = torch.Size(cur_matrix.shape)
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor.
    tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
    # Only move to GPU when one exists; the unconditional .cuda() crashed
    # on CPU-only hosts. Behavior is unchanged on CUDA machines.
    return tensor.cuda() if torch.cuda.is_available() else tensor

def graph_drop(graph, keepRate):
    """Apply edge dropout to a sparse adjacency tensor.

    Each edge survives independently with probability ``keepRate``;
    surviving edge weights are rescaled by ``1/keepRate`` so the
    expected adjacency is unchanged (inverted-dropout style).

    Parameters
    ----------
    graph : torch.Tensor
        Sparse COO tensor (uses the internal ``_values``/``_indices``
        accessors, so an uncoalesced tensor is accepted).
    keepRate : float
        Probability in (0, 1] of keeping each edge.

    Returns
    -------
    torch.Tensor
        Sparse COO tensor with the surviving, rescaled edges.
    """
    vals = graph._values()
    idxs = graph._indices()
    edgeNum = vals.size()
    # rand() is in [0, 1): (rand + keepRate).floor() is 1 w.p. keepRate.
    # Build the mask on the graph's device — a CPU mask cannot index a
    # CUDA tensor, which the original code attempted.
    mask = ((torch.rand(edgeNum, device=vals.device) + keepRate).floor()).type(torch.bool)
    newVals = vals[mask] / keepRate
    newIdxs = idxs[:, mask]
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor.
    return torch.sparse_coo_tensor(newIdxs, newVals, graph.shape)


class Light_GCN(nn.Module):
    """LightGCN encoder over the user-item bipartite graph.

    Propagates user/item embeddings through ``n_layers`` rounds of
    sparse neighborhood aggregation on the symmetrically normalized
    adjacency L = D^{-1/2} A D^{-1/2}, then averages all layer outputs
    (LightGCN's uniform layer combination).
    """

    def __init__(self, emb_size):
        """Build embeddings and the normalized adjacency.

        Parameters
        ----------
        emb_size : int
            Dimensionality of the user and item embeddings.
        """
        super(Light_GCN, self).__init__()
        self.n_layers = 2
        self.emb_size = emb_size
        # NOTE(review): pickle.load on the train_mat file — safe only for
        # trusted, locally produced data.
        self.ui_graph = pickle.load(open(args.data_path + args.dataset + '/train_mat', 'rb'))
        self.n_user = self.ui_graph.shape[0]
        self.n_item = self.ui_graph.shape[1]
        self.user_emb = nn.Embedding(self.n_user, self.emb_size)
        self.item_emb = nn.Embedding(self.n_item, self.emb_size)
        nn.init.xavier_uniform_(self.item_emb.weight)
        nn.init.xavier_uniform_(self.user_emb.weight)
        # Assemble the square bipartite adjacency A = [[0, R], [R^T, 0]].
        A = sp.dok_matrix((self.n_user + self.n_item, self.n_user + self.n_item), dtype=np.float32)
        A = A.tolil()
        R = self.ui_graph.todok()
        A[:self.n_user, self.n_user:] = R
        A[self.n_user:, :self.n_user] = R.T
        # Symmetric normalization D^{-1/2} A D^{-1/2}; 1e-7 guards against
        # division by zero for isolated nodes.
        sumArr = (A > 0).sum(axis=1)
        diag = np.array(sumArr.flatten())[0] + 1e-7
        diag = np.power(diag, -0.5)
        D = sp.diags(diag)
        L = D * A * D
        self.L = sp.coo_matrix(L)
        # Convert to a torch sparse tensor ONCE. The original rebuilt this
        # tensor inside the layer loop on every forward pass — redundant
        # work, since only the dropout mask needs to be fresh.
        self.adj = matrix_to_tensor(self.L)

    def forward(self):
        """Run LightGCN propagation.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            (user_embeddings, item_embeddings) of shapes
            (n_user, emb_size) and (n_item, emb_size).
        """
        all_emb = torch.cat([self.user_emb.weight, self.item_emb.weight])
        emb_list = [all_emb]
        for _ in range(self.n_layers):
            # Fresh edge dropout each layer; the adjacency tensor is reused.
            all_emb = torch.sparse.mm(graph_drop(self.adj, args.keep_rate), all_emb)
            emb_list.append(all_emb)
        # Average the initial embedding and every layer output
        # (alpha_k = 1/(K+1) in the LightGCN formulation).
        all_emb = torch.mean(torch.stack(emb_list, dim=1), dim=1)
        user_all_embeddings, item_all_embeddings = torch.split(all_emb, [self.n_user, self.n_item])
        return user_all_embeddings, item_all_embeddings